{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2025, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}
#include <stdio.h>

#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>

#include "caspar_mappings.h"

namespace cg = cooperative_groups;

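// Stacked data is staged through shared memory so that global-memory accesses are coalesced.
// A small block size of 128 keeps the per-block shared buffer (block_size * storage_dim floats)
// small enough to support node types with a large storage dimension.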
constexpr int block_size = 128;

namespace caspar {
{% for nodetype in caslib.exposed_types%}

// Converts `num_objects` objects of this node type from the stacked layout
// (storage_dim consecutive floats per object) into the caspar layout.
//
// Launch layout: 1D grid of 1D blocks, one thread per object, with
// blockDim.x <= block_size (the static shared buffer holds block_size objects).
//
// Phase 1: the block copies the stacked floats of its objects into shared memory,
//          consecutive threads reading consecutive global addresses.
// Phase 2: each thread gathers the fields of its own object chunk by chunk and
//          writes every chunk with a single scalar/vector store.
//
// Chunk `c` of object `i` is written to
//   cas_data + caspar_size(len(c)) * (i + cas_offset) + offset_c * cas_stride
// where offset_c is the sum of caspar_size(len(.)) over all preceding chunks.
//
// stacked_data: input, num_objects * storage_dim floats.
// cas_data:     output buffer in caspar layout.
// cas_stride:   multiplier applied to the accumulated chunk offset.
// cas_offset:   index of the first object inside the caspar buffer.
// num_objects:  number of objects to convert.
//
// NOTE(review): the float2/float4 stores need the computed destination address to be
// 8/16-byte aligned; this presumably follows from how cas_data, cas_stride and
// caspar_size are chosen by the caller -- confirm.
__global__ __launch_bounds__(block_size, 1)
void {{nodetype.__name__}}_stacked_to_caspar_kernel(const float* const __restrict__ stacked_data,
                          float* const __restrict__ cas_data,
                          const unsigned int cas_stride,
                          const unsigned int cas_offset,
                          const unsigned int num_objects) {
  constexpr unsigned int storage_dim = {{Ops.storage_dim(nodetype)}};
  const unsigned int block_first = blockIdx.x * blockDim.x;
  const unsigned int global_thread_idx = block_first + threadIdx.x;
  __shared__ float stacked_data_local[block_size * storage_dim];

  // Number of objects (and floats) this block is responsible for; the last block may be partial.
  const unsigned int block_objects =
      block_first < num_objects ? min(blockDim.x, num_objects - block_first) : 0u;
  const unsigned int block_floats = block_objects * storage_dim;

  // The base offset is widened to size_t so that it cannot wrap for very large inputs.
  const float* const block_src = stacked_data + static_cast<size_t>(block_first) * storage_dim;
  for (unsigned int i = threadIdx.x; i < block_floats; i += blockDim.x) {
    stacked_data_local[i] = block_src[i];
  }

  // Every thread of the block reaches this barrier, including threads without an object.
  __syncthreads();

  if (global_thread_idx >= num_objects) {
    return;
  }

  // 16-byte alignment is required for the float2/float4 reinterpret_cast reads below.
  alignas(16) float data[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  const float* const stacked_local_ptr = stacked_data_local + threadIdx.x * storage_dim;
  // Position of this thread's object inside the caspar buffer (64-bit to avoid wrap-around).
  const size_t object_slot = static_cast<size_t>(global_thread_idx) + cas_offset;
  float* out_ptr;
  {% set ns = namespace(offset=0) %}
  {% for chunk in get_layout(nodetype) %}
  {# Gather the (at most four) stacked fields that make up this chunk. #}
  {% for i in range(4) %}
  {% if i < len(chunk) %}
  data[{{i}}] = stacked_local_ptr[{{chunk[i]}}];
  {% endif %}
  {% endfor %}

  out_ptr = cas_data + {{caspar_size(len(chunk))}} * object_slot + static_cast<size_t>({{ns.offset}}) * cas_stride;
  {% if len(chunk) == 1 %}
  out_ptr[0] = data[0];
  {% elif len(chunk) == 2 %}
  reinterpret_cast<float2*>(out_ptr)[0] = reinterpret_cast<float2*>(data)[0];
  {% elif len(chunk) == 3 %}
  reinterpret_cast<float3*>(out_ptr)[0] = reinterpret_cast<float3*>(data)[0];
  {% elif len(chunk) == 4 %}
  reinterpret_cast<float4*>(out_ptr)[0] = reinterpret_cast<float4*>(data)[0];
  {% endif %}
  {% set ns.offset = ns.offset + caspar_size(len(chunk)) %}
  {% endfor %}
}

// Converts `num_objects` objects of this node type from the caspar layout back into
// the stacked layout (storage_dim consecutive floats per object).
//
// Launch layout: 1D grid of 1D blocks, one thread per object, with
// blockDim.x <= block_size (the static shared buffer holds block_size objects).
//
// Phase 1: each thread loads the chunks of its own object with one scalar/vector
//          load per chunk and scatters the fields into its slot of shared memory.
// Phase 2: the block copies the shared buffer to global memory, consecutive threads
//          writing consecutive addresses.
//
// Chunk `c` of object `i` is read from
//   cas_data + caspar_size(len(c)) * (i + cas_offset) + offset_c * cas_stride
// where offset_c is the sum of caspar_size(len(.)) over all preceding chunks.
//
// cas_data:     input buffer in caspar layout.
// stacked_data: output, num_objects * storage_dim floats.
// cas_stride:   multiplier applied to the accumulated chunk offset.
// cas_offset:   index of the first object inside the caspar buffer.
// num_objects:  number of objects to convert.
//
// NOTE(review): the float2/float4 loads need the computed source address to be
// 8/16-byte aligned; this presumably follows from how cas_data, cas_stride and
// caspar_size are chosen by the caller -- confirm.
__global__ __launch_bounds__(block_size, 1)
void {{nodetype.__name__}}_caspar_to_stacked_kernel(const float* const __restrict__ cas_data,
                             float* const __restrict__ stacked_data,
                             const unsigned int cas_stride,
                             const unsigned int cas_offset,
                             const unsigned int num_objects) {
  constexpr unsigned int storage_dim = {{Ops.storage_dim(nodetype)}};
  const unsigned int block_first = blockIdx.x * blockDim.x;
  const unsigned int global_thread_idx = block_first + threadIdx.x;
  __shared__ float stacked_data_local[block_size * storage_dim];

  if (global_thread_idx < num_objects) {
    // 16-byte alignment is required for the float2/float4 reinterpret_cast writes below.
    alignas(16) float data[4] = {0.0f, 0.0f, 0.0f, 0.0f};
    float* const stacked_local_ptr = stacked_data_local + threadIdx.x * storage_dim;
    // Position of this thread's object inside the caspar buffer (64-bit to avoid wrap-around).
    const size_t object_slot = static_cast<size_t>(global_thread_idx) + cas_offset;
    const float* in_ptr;
    {% set ns = namespace(offset=0) %}
    {% for chunk in get_layout(nodetype) %}
    in_ptr = cas_data + {{caspar_size(len(chunk))}} * object_slot + static_cast<size_t>({{ns.offset}}) * cas_stride;
    {% if len(chunk) == 1 %}
    data[0] = in_ptr[0];
    {% elif len(chunk) == 2 %}
    reinterpret_cast<float2*>(data)[0] = reinterpret_cast<const float2*>(in_ptr)[0];
    {% elif len(chunk) == 3 %}
    reinterpret_cast<float3*>(data)[0] = reinterpret_cast<const float3*>(in_ptr)[0];
    {% elif len(chunk) == 4 %}
    reinterpret_cast<float4*>(data)[0] = reinterpret_cast<const float4*>(in_ptr)[0];
    {% endif %}
    {# Scatter the (at most four) loaded values to their stacked field positions. #}
    {% for i in range(4) %}
    {% if i < len(chunk) %}
    stacked_local_ptr[{{chunk[i]}}] = data[{{i}}];
    {% endif %}
    {% endfor %}
    {% set ns.offset = ns.offset + caspar_size(len(chunk)) %}
    {% endfor %}
  }

  // Kept outside the branch above so that every thread of the block reaches the barrier.
  __syncthreads();

  // Number of objects (and floats) this block is responsible for; the last block may be partial.
  const unsigned int block_objects =
      block_first < num_objects ? min(blockDim.x, num_objects - block_first) : 0u;
  const unsigned int block_floats = block_objects * storage_dim;

  // The base offset is widened to size_t so that it cannot wrap for very large outputs.
  float* const block_dst = stacked_data + static_cast<size_t>(block_first) * storage_dim;
  for (unsigned int i = threadIdx.x; i < block_floats; i += blockDim.x) {
    block_dst[i] = stacked_data_local[i];
  }
}

// Host entry point: launches the stacked -> caspar conversion kernel for this node
// type with one thread per object and block_size threads per block.
//
// stacked_data / cas_data are handed to the kernel unchanged and must therefore be
// device-accessible. No stream is specified, so the kernel is enqueued on the default
// stream; this function does not synchronize with the device.
//
// Returns cudaSuccess without launching anything when num_objects is zero (a grid of
// zero blocks is an invalid launch configuration); otherwise returns the result of
// cudaGetLastError() after the launch.
cudaError_t {{nodetype.__name__}}_stacked_to_caspar(const float* stacked_data,
                                                  float* cas_data,
                                                  const unsigned int cas_stride,
                                                  const unsigned int cas_offset,
                                                  const unsigned int num_objects
                                                  ) {
  if (num_objects == 0) {
    return cudaSuccess;
  }

  // Ceiling division written so that it cannot wrap for num_objects close to UINT_MAX.
  const unsigned int num_blocks =
      num_objects / block_size + (num_objects % block_size != 0 ? 1u : 0u);

  {{nodetype.__name__}}_stacked_to_caspar_kernel<<<num_blocks, block_size>>>(stacked_data, cas_data, cas_stride, cas_offset, num_objects);

  return cudaGetLastError();
}

// Host entry point: launches the caspar -> stacked conversion kernel for this node
// type with one thread per object and block_size threads per block.
//
// cas_data / stacked_data are handed to the kernel unchanged and must therefore be
// device-accessible. No stream is specified, so the kernel is enqueued on the default
// stream; this function does not synchronize with the device.
//
// Returns cudaSuccess without launching anything when num_objects is zero (a grid of
// zero blocks is an invalid launch configuration); otherwise returns the result of
// cudaGetLastError() after the launch.
cudaError_t {{nodetype.__name__}}_caspar_to_stacked(const float* cas_data,
                                                    float* stacked_data,
                                                    const unsigned int cas_stride,
                                                    const unsigned int cas_offset,
                                                    const unsigned int num_objects) {
  if (num_objects == 0) {
    return cudaSuccess;
  }

  // Ceiling division written so that it cannot wrap for num_objects close to UINT_MAX.
  const unsigned int num_blocks =
      num_objects / block_size + (num_objects % block_size != 0 ? 1u : 0u);

  {{nodetype.__name__}}_caspar_to_stacked_kernel<<<num_blocks, block_size>>>(cas_data, stacked_data, cas_stride, cas_offset, num_objects);

  return cudaGetLastError();
}

{% endfor %}
}  // namespace caspar
